﻿#include "precompiled.h"
#include "common.h"
#include "Transform.h"

#include "Entity.h"
#include "MeshRenderer.h"

using namespace DirectX;

namespace RTCam {

// Creates an identity transform: zero translation, unit scale and the identity quaternion
// (0, 0, 0, 1), for both the local values and the cached world values. Both the local and the
// global matrix caches start out flagged as needing a rebuild, so the first query computes them.
// m_parentNeedsUpdate / m_childrenNeedUpdate are also set to true; they are not read in this file.
Transform::Transform(const shared_ptr<Entity>& entity)
	: Component(entity)
	, m_localMatrixNeedsUpdate(true)
	, m_globalMatrixNeedsUpdate(true)
	, m_parentNeedsUpdate(true)
	, m_childrenNeedUpdate(true)
	, m_localPosition(0.0f, 0.0f, 0.0f)
	, m_localScale(1.0f, 1.0f, 1.0f)
	, m_localRotation(0.0f, 0.0f, 0.0f, 1.0f)
	, m_cachedWorldPosition(0.0f, 0.0f, 0.0f)
	, m_cachedWorldScale(1.0f, 1.0f, 1.0f)
	, m_cachedWorldRotation(0.0f, 0.0f, 0.0f, 1.0f)
{
}

// No explicit cleanup: every member is released by its own destructor.
Transform::~Transform()
{
}

void Transform::UpdateWorldAndNormalMatrices()
{
	auto entity = m_entity.lock();
	ASSERT(entity != nullptr);

	//// SUPER HACKY TEST CODE
	//XMVECTOR rotation = XMLoadFloat4A(&GetLocalRotation());
	////XMVECTOR additionalRotation = XMQuaternionRotationRollPitchYaw(0.02f, 0.01f, 0);
	//XMVECTOR additionalRotation = XMQuaternionRotationRollPitchYaw(0, 0.01f, 0);
	//rotation = XMQuaternionMultiply(rotation, additionalRotation);
	//XMFLOAT4A newRotation;
	//XMStoreFloat4A(&newRotation, rotation);
	//SetLocalRotation(newRotation);
	//// END TEST CODE

	XMMATRIX globalTransform = GetGlobalTransformMatrix();

	// The normals are transformed by the transpose of the inverse of the linear transformation (scale and rotation).
	// Set the translation to 0 and get the inverse matrix.
	// (The matrix isn't transposed as it'll have to be transposed again when copying to the constant buffer).
	globalTransform.r[3] = XMVectorSet(0, 0, 0, 1);
	XMVECTOR determinant;
	globalTransform = XMMatrixInverse(&determinant, globalTransform);
	ASSERT(XMVectorGetX(determinant) != 0); // Make sure the matrix inverted successfully

	XMStoreFloat4x4A(&entity->m_modelCBufferData.Normal, globalTransform);
}

// Refreshes the per-frame world-view-projection data in entity->m_modelCBufferData:
//   - WorldViewProjection:    world * viewProjection for this frame,
//   - InvWorldViewProjection: its inverse (for reconstructing position from the depth buffer),
//   - PrevWorldViewProjection: last frame's WorldViewProjection, or this frame's when
//                              resetCameraMotion is true (so no motion is implied across a reset).
// All three are stored transposed for the constant buffer.
//   viewProjection    - combined view-projection matrix, applied after the world matrix.
//   resetCameraMotion - discard the previous frame's matrix (see above).
// Requires UpdateWorldAndNormalMatrices() to have been called earlier this frame; debug builds
// check this through m_globalMatrixNeedsUpdate.
void Transform::UpdateWorldViewProjectionMatrix(CXMMATRIX viewProjection, bool resetCameraMotion)
{
	auto entity = m_entity.lock();
	ASSERT(entity != nullptr);
	// The only function in this file that refreshes the global matrix each frame is
	// UpdateWorldAndNormalMatrices; the old message referred to a nonexistent function.
	ASSERT_MSG(!m_globalMatrixNeedsUpdate, "UpdateWorldAndNormalMatrices should have been called earlier this frame");

	const XMMATRIX worldTransform = GetGlobalTransformMatrix();
	const XMMATRIX worldViewProjection = XMMatrixMultiply(worldTransform, viewProjection);

	// Calculate the inverse world-view-projection matrix (for reconstructing position from the depth buffer)
	XMVECTOR determinant;
	const XMMATRIX invWorldViewProjection = XMMatrixInverse(&determinant, worldViewProjection);
	ASSERT(XMVectorGetX(determinant) != 0); // Make sure the matrix inverted successfully

	// Make the previous world-view-projection matrix the same as the current if the camera's motion
	// has been reset. Otherwise, carry over the previous frame's world-view-projection matrix.
	// (This must happen before WorldViewProjection is overwritten below.)
	if(resetCameraMotion) {
		XMStoreFloat4x4A(&entity->m_modelCBufferData.PrevWorldViewProjection, XMMatrixTranspose(worldViewProjection));
	} else {
		entity->m_modelCBufferData.PrevWorldViewProjection = entity->m_modelCBufferData.WorldViewProjection;
	}

	// Update the world-view-projection matrix
	XMStoreFloat4x4A(&entity->m_modelCBufferData.WorldViewProjection, XMMatrixTranspose(worldViewProjection));

	// Store the inverse world-view-projection matrix
	XMStoreFloat4x4A(&entity->m_modelCBufferData.InvWorldViewProjection, XMMatrixTranspose(invWorldViewProjection));
}

//--------------------------------------------------------------------------------------
// World space transformations

namespace {

// Returns basis row `axis` (0 = x/right, 1 = y/up, 2 = z/forward) of the rotation contained in
// `worldMatrix`, i.e. the direction of that local axis in world space with scale removed.
//
// XMMatrixDecompose returns false when the matrix cannot be split into scale/rotation/translation
// (e.g. a sheared matrix, which a rotated child under a non-uniformly scaled parent can produce).
// In that case it does not write the rotation quaternion, so using it would read an uninitialized
// value. Fall back to normalizing the matrix row itself, which still points along the transformed axis.
XMFLOAT3A ExtractWorldBasis(const XMMATRIX& worldMatrix, int axis)
{
	XMVECTOR scale, rotQuat, translation;
	XMVECTOR basis;
	if(XMMatrixDecompose(&scale, &rotQuat, &translation, worldMatrix)) {
		basis = XMMatrixRotationQuaternion(rotQuat).r[axis];
	} else {
		basis = XMVector3Normalize(worldMatrix.r[axis]);
	}

	XMFLOAT3A result;
	XMStoreFloat3A(&result, basis);
	return result;
}

} // end anonymous namespace

// World-space direction of this transform's local +Z axis (rotation only, scale removed).
XMFLOAT3A Transform::GetForward() const
{
	// TODO: Cache the decomposed vectors
	return ExtractWorldBasis(GetGlobalTransformMatrix(), 2);
}

// World-space direction of this transform's local +X axis (rotation only, scale removed).
XMFLOAT3A Transform::GetRight() const
{
	return ExtractWorldBasis(GetGlobalTransformMatrix(), 0);
}

// World-space direction of this transform's local +Y axis (rotation only, scale removed).
XMFLOAT3A Transform::GetUp() const
{
	return ExtractWorldBasis(GetGlobalTransformMatrix(), 1);
}

// GetForward() loaded into an XMVECTOR.
XMVECTOR Transform::GetForwardVector() const
{
	// Copy into a named local first: taking the address of the temporary returned by GetForward()
	// is not valid standard C++ (MSVC only accepts it as an extension, warning C4238).
	const XMFLOAT3A forward = GetForward();
	return XMLoadFloat3A(&forward);
}

// GetRight() loaded into an XMVECTOR (mirrors GetForwardVector).
XMVECTOR Transform::GetRightVector() const
{
	const XMFLOAT3A right = GetRight();
	return XMLoadFloat3A(&right);
}

// GetUp() loaded into an XMVECTOR (mirrors GetForwardVector).
XMVECTOR Transform::GetUpVector() const
{
	const XMFLOAT3A up = GetUp();
	return XMLoadFloat3A(&up);
}

// World-space position of this transform's origin.
// Matrices in this class keep the translation in the fourth row (UpdateLocalTransform writes the
// local position to r[3]), and the parent's global matrix is applied by post-multiplication
// (local * parent). The translation row of the global matrix is therefore the world position.
// This is the same value XMMatrixDecompose would report as the translation.
DirectX::XMFLOAT3A Transform::GetWorldPosition() const
{
	XMFLOAT3A position;
	XMStoreFloat3A(&position, GetGlobalTransformMatrix().r[3]);
	return position;
}

// World-space rotation (quaternion) of this transform.
// Not implemented yet: the body only invokes THROW_UNIMPLEMENTED_EXCEPTION().
// NOTE(review): XMMatrixDecompose(GetGlobalTransformMatrix()) is the obvious implementation, but it
// returns false for non-SRT (sheared) matrices and leaves the quaternion unwritten. Decide on a
// failure policy before implementing.
DirectX::XMFLOAT4A Transform::GetWorldRotation() const
{
	THROW_UNIMPLEMENTED_EXCEPTION();
}

// World-space scale of this transform.
// Not implemented yet: the body only invokes THROW_UNIMPLEMENTED_EXCEPTION().
// NOTE(review): with a mirrored (negative-determinant) or sheared global matrix the per-axis scale
// is not uniquely defined. Confirm the expected behavior before implementing.
DirectX::XMFLOAT3A Transform::GetWorldScale() const
{
	THROW_UNIMPLEMENTED_EXCEPTION();
}

// Moves this transform so its world-space position becomes `position`.
// Not implemented yet: the body only invokes THROW_UNIMPLEMENTED_EXCEPTION().
void Transform::SetWorldPosition( const DirectX::XMFLOAT3A& position )
{
	// Planned approach: localTransform = goalTransform * Inverse(parentTransform)
	THROW_UNIMPLEMENTED_EXCEPTION();
}

// Rotates this transform so its world-space rotation becomes `rotation` (a quaternion).
// Not implemented yet: the body only invokes THROW_UNIMPLEMENTED_EXCEPTION().
void Transform::SetWorldRotation( const DirectX::XMFLOAT4A& rotation )
{
	THROW_UNIMPLEMENTED_EXCEPTION();
}

//--------------------------------------------------------------------------------------
// Local transformations

DirectX::XMFLOAT3A Transform::GetLocalPosition() const
{
	return m_localPosition;
}

DirectX::XMFLOAT3A Transform::GetLocalScale() const
{
	return m_localScale;
}

DirectX::XMFLOAT4A Transform::GetLocalRotation() const	
{
	return m_localRotation;
}

void Transform::SetLocalPosition(const DirectX::XMFLOAT3A& position)
{
	m_localPosition = position;
	InvalidateLocalTransformMatrix();
}

void Transform::SetLocalRotation(const DirectX::XMFLOAT4A& rotation)
{
	m_localRotation = rotation;
	InvalidateLocalTransformMatrix();
}

void Transform::SetLocalScale(const DirectX::XMFLOAT3A& scale)
{
	m_localScale = scale;
	InvalidateLocalTransformMatrix();
}

//--------------------------------------------------------------------------------------


// Rebuilds the local matrix from the local scale, rotation and position (row-vector convention):
// the rows of the rotation matrix are scaled per axis, and the position goes into the fourth row
// with w = 1. The result is stored in m_cachedLocalTransformMatrix, the dirty flag is cleared, and
// the freshly built matrix is returned.
XMMATRIX Transform::UpdateLocalTransform() const
{
	XMMATRIX matrix = XMMatrixRotationQuaternion(XMLoadFloat4A(&m_localRotation));

	const float axisScale[3] = { m_localScale.x, m_localScale.y, m_localScale.z };
	for(int row = 0; row < 3; ++row) {
		matrix.r[row] = XMVectorScale(matrix.r[row], axisScale[row]);
	}

	matrix.r[3] = XMVectorSetW(XMLoadFloat3A(&m_localPosition), 1.0f);

	XMStoreFloat4x4A(&m_cachedLocalTransformMatrix, matrix);
	m_localMatrixNeedsUpdate = false;
	return matrix;
}

// Rebuilds the global matrix: the local matrix alone for a root entity, otherwise the local matrix
// followed by the parent's global matrix (local * parent, row-vector convention). The parent's
// global matrix is itself rebuilt on demand. The result is cached, the dirty flag is cleared, and
// the matrix is returned.
XMMATRIX Transform::UpdateGlobalTransform() const
{
	// Check the owning entity like the other methods do, rather than dereferencing the result of
	// lock() blindly.
	auto entity = m_entity.lock();
	ASSERT(entity != nullptr);

	XMMATRIX globalTransform;

	auto parent = entity->GetParent();
	if(parent == nullptr) {
		// No parent, global = local
		globalTransform = GetLocalTransformMatrix();
	} else {
		// Stack the parent's global transform on top of the local transform.
		XMMATRIX parentTransform = parent->GetTransform()->GetGlobalTransformMatrix();
		XMMATRIX localTransform = GetLocalTransformMatrix();
		globalTransform = XMMatrixMultiply(localTransform, parentTransform);
	}

	XMStoreFloat4x4A(&m_cachedGlobalTransformMatrix, globalTransform);
	m_globalMatrixNeedsUpdate = false;

	return globalTransform;
}

XMFLOAT4X4A Transform::GetLocalTransform() const
{
	if(m_localMatrixNeedsUpdate) {
		UpdateLocalTransform();
	}

	return m_cachedLocalTransformMatrix;
}

DirectX::XMMATRIX Transform::GetLocalTransformMatrix() const
{
	if(m_localMatrixNeedsUpdate) {
		return UpdateLocalTransform();
	} else {
		return XMLoadFloat4x4A(&m_cachedLocalTransformMatrix);
	}
}

XMFLOAT4X4A Transform::GetGlobalTransform() const
{
	if(m_globalMatrixNeedsUpdate) {
		UpdateGlobalTransform();
	}
	return m_cachedGlobalTransformMatrix;
}

DirectX::XMMATRIX Transform::GetGlobalTransformMatrix() const
{
	if(m_globalMatrixNeedsUpdate) {
		return UpdateGlobalTransform();
	} else {
		return XMLoadFloat4x4A(&m_cachedGlobalTransformMatrix);
	}
}

//--------------------------------------------------------------------------------------

// Flags the local matrix as stale after a change to position, rotation or scale. The global
// matrix is derived from the local one, so the global matrices of this transform and its
// descendants are invalidated too. Debug builds also zero the stale cache (presumably so that
// reading it by mistake is easy to spot).
void Transform::InvalidateLocalTransformMatrix()
{
	m_localMatrixNeedsUpdate = true;

#ifdef _DEBUG
	ZeroMemory(&m_cachedLocalTransformMatrix, sizeof(m_cachedLocalTransformMatrix));
#endif

	InvalidateGlobalTransformMatrix();
}

void Transform::InvalidateGlobalTransformMatrix()
{
#ifdef _DEBUG
	ZeroMemory(&m_cachedGlobalTransformMatrix, sizeof(m_cachedGlobalTransformMatrix));
#endif

	// The global matrix of this transform and any children will need to be updated.
	m_globalMatrixNeedsUpdate = true;
	for(auto& child: m_entity.lock()->GetChildren()) {
		child->GetTransform()->InvalidateGlobalTransformMatrix();
	}
}

//--------------------------------------------------------------------------------------

// Maps a point from this transform's local space to world space.
// Not implemented yet: the body only invokes THROW_UNIMPLEMENTED_EXCEPTION().
// NOTE(review): "local" is ambiguous in this class. GetLocalPosition() and friends are
// parent-relative, but a point "in local space" usually means the object's own space, which would
// map through the full global matrix (e.g. XMVector3TransformCoord with GetGlobalTransformMatrix()).
// Confirm the intended space before implementing.
XMFLOAT3A Transform::TransformPointLocalToWorld( const XMFLOAT3A& position )
{
	THROW_UNIMPLEMENTED_EXCEPTION();
}

// Maps a point from world space to this transform's local space (the inverse of
// TransformPointLocalToWorld).
// Not implemented yet: the body only invokes THROW_UNIMPLEMENTED_EXCEPTION().
// NOTE(review): presumably uses the inverse of whichever matrix TransformPointLocalToWorld ends up
// using — confirm together with that function.
XMFLOAT3A Transform::TransformPointWorldToLocal( const XMFLOAT3A& position )
{
	THROW_UNIMPLEMENTED_EXCEPTION();
}

} // end namespace